{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建多层的 GRU 网络实现 MNIST 分类\n",
    "\n",
    "在上个例子中，我们已经理解了在 TensorFlow 中如何来实现 LSTM， 在本例子中来实现以下 GRU。和 LSTM 基本上一样，只是基本模型换成了 GRU。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n",
      "WARNING:tensorflow:From <ipython-input-1-fbc070ee2f3a>:15: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../data/MNIST_data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../data/MNIST_data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ../data/MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../data/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "(10000, 10)\n",
      "(55000, 10)\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning \n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# 用tensorflow 导入数据\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True) \n",
    "\n",
    "# 看看咱们样本的数量\n",
    "print(mnist.test.labels.shape)\n",
    "print(mnist.train.labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 一、首先设置好模型用到的各个超参数 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-3\n",
    "input_size = 28      # 每个时刻的输入特征是28维的，就是每个时刻输入一行，一行有 28 个像素\n",
    "timestep_size = 28   # 时序持续长度为28，即每做一次预测，需要先输入28行\n",
    "hidden_size = 256    # 隐含层的数量\n",
    "layer_num = 2        # LSTM layer 的层数\n",
    "class_num = 10       # 最后输出分类类别数量，如果是回归预测的话应该是 1\n",
    "cell_type = \"block_gru\"   # gru 或者 block_gru\n",
    "\n",
    "X_input = tf.placeholder(tf.float32, [None, 784])\n",
    "y_input = tf.placeholder(tf.float32, [None, class_num])\n",
    "# 在训练和测试的时候，我们想用不同的 batch_size.所以采用占位符的方式\n",
    "batch_size = tf.placeholder(tf.int32, [])  # 注意类型必须为 tf.int32, batch_size = 128\n",
    "keep_prob = tf.placeholder(tf.float32, [])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 二、开始搭建 GRU 模型，和 LSTM 模型基本一致 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 把784个点的字符信息还原成 28 * 28 的图片\n",
    "# 下面几个步骤是实现 RNN / gru 的关键\n",
    "\n",
    "# **步骤1：RNN 的输入shape = (batch_size, timestep_size, input_size) \n",
    "X = tf.reshape(X_input, [-1, 28, 28])\n",
    "\n",
    "# ** 步骤2：创建 gru 结构\n",
    "def gru_cell(cell_type, num_nodes, keep_prob):\n",
    "    assert(cell_type in [\"gru\", \"block_gru\"], \"Wrong cell type.\")\n",
    "    if cell_type == \"gru\":\n",
    "        cell = tf.contrib.rnn.GRUCell(num_nodes)\n",
    "    else:\n",
    "        cell = tf.contrib.rnn.GRUBlockCellV2(num_nodes)\n",
    "    cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)\n",
    "    return cell\n",
    "\n",
    "mgru_cell = tf.contrib.rnn.MultiRNNCell([gru_cell(cell_type, hidden_size, keep_prob) for _ in range(layer_num)], state_is_tuple = True)\n",
    "\n",
    "# **步骤3：用全零来初始化state\n",
    "init_state = mgru_cell.zero_state(batch_size, dtype=tf.float32)\n",
    "\n",
    "# **步骤4：方法一，调用 dynamic_rnn() 来让我们构建好的网络运行起来\n",
    "# ** 当 time_major==False 时， outputs.shape = [batch_size, timestep_size, hidden_size] \n",
    "# ** 所以，可以取 h_state = outputs[:, -1, :] 作为最后输出\n",
    "# ** state.shape = [layer_num, 2, batch_size, hidden_size], \n",
    "# ** 或者，可以取 h_state = state[-1][1] 作为最后输出\n",
    "# ** 最后输出维度是 [batch_size, hidden_size]\n",
    "# outputs, state = tf.nn.dynamic_rnn(mgru_cell, inputs=X, initial_state=init_state, time_major=False)\n",
    "# h_state = state[-1][1]\n",
    "\n",
    "# # *************** 为了更好的理解 gru 工作原理，我们把上面 步骤6 中的函数自己来实现 ***************\n",
    "# # 通过查看文档你会发现， RNNCell 都提供了一个 __call__()函数，我们可以用它来展开实现gru按时间步迭代。\n",
    "# # **步骤4：方法二，按时间步展开计算\n",
    "outputs = list()\n",
    "state = init_state\n",
    "with tf.variable_scope('RNN'):\n",
    "    for timestep in range(timestep_size):\n",
    "        (cell_output, state) = mgru_cell(X[:, timestep, :],state)\n",
    "        outputs.append(cell_output)\n",
    "h_state = outputs[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 三、最后设置 loss function 和 优化器，展开训练并完成测试 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 500, train cost=0.012643, acc=0.980000; test cost=0.011262, acc=0.964400; pass 11.324278831481934s\n",
      "step 1000, train cost=0.006675, acc=0.980000; test cost=0.007146, acc=0.976100; pass 10.350492238998413s\n",
      "step 1500, train cost=0.006741, acc=0.970000; test cost=0.005976, acc=0.981800; pass 10.223833799362183s\n",
      "step 2000, train cost=0.014265, acc=0.940000; test cost=0.005539, acc=0.983600; pass 10.121492385864258s\n",
      "step 2500, train cost=0.000602, acc=1.000000; test cost=0.004548, acc=0.986300; pass 10.094793796539307s\n",
      "step 3000, train cost=0.003602, acc=0.990000; test cost=0.004053, acc=0.989000; pass 10.224550247192383s\n",
      "step 3500, train cost=0.003195, acc=0.990000; test cost=0.004875, acc=0.985200; pass 10.053868293762207s\n",
      "step 4000, train cost=0.003112, acc=0.990000; test cost=0.004069, acc=0.989200; pass 10.328427076339722s\n",
      "step 4500, train cost=0.003034, acc=0.990000; test cost=0.004045, acc=0.988600; pass 10.144667387008667s\n",
      "step 5000, train cost=0.001717, acc=0.990000; test cost=0.004875, acc=0.987300; pass 10.119761943817139s\n"
     ]
    }
   ],
   "source": [
    "import time \n",
    "\n",
    "# 开始训练和测试\n",
    "W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)\n",
    "bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)\n",
    "y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)\n",
    "\n",
    "\n",
    "# 损失和评估函数\n",
    "cross_entropy = -tf.reduce_mean(y_input * tf.log(y_pre))\n",
    "train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)\n",
    "\n",
    "correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y_input,1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))\n",
    "\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "time0 = time.time()\n",
    "for i in range(5000):\n",
    "    _batch_size=100\n",
    "    X_batch, y_batch = mnist.train.next_batch(batch_size=_batch_size)\n",
    "    cost, acc,  _ = sess.run([cross_entropy, accuracy, train_op], feed_dict={X_input: X_batch, y_input: y_batch, keep_prob: 0.5, batch_size: _batch_size})\n",
    "    if (i+1) % 500 == 0:\n",
    "        # 分 100 个batch 迭代\n",
    "        test_acc = 0.0\n",
    "        test_cost = 0.0\n",
    "        N = 100\n",
    "        for j in range(N):\n",
    "            X_batch, y_batch = mnist.test.next_batch(batch_size=_batch_size)\n",
    "            _cost, _acc = sess.run([cross_entropy, accuracy], feed_dict={X_input: X_batch, y_input: y_batch, keep_prob: 1.0, batch_size: _batch_size})\n",
    "            test_acc += _acc\n",
    "            test_cost += _cost\n",
    "        print(\"step {}, train cost={:.6f}, acc={:.6f}; test cost={:.6f}, acc={:.6f}; pass {}s\".format(i+1, cost, acc, test_cost/N, test_acc/N, time.time() - time0))\n",
    "        time0 = time.time()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "和前面 lstm 的例子比较一下，二者的准确率并没有太大的差别; 速度没有明显区别；显存占用也没有明显的区别。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 四、可视化看看 GRU 的是怎么做分类的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "毕竟 GRU更多的是用来做时序相关的问题，要么是文本，要么是序列预测之类的，所以很难像 CNNs 一样非常直观地看到每一层中特征的变化。在这里，我想通过可视化的方式来帮助大家理解 GRU 是怎么样一步一步地把图片正确的给分类。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 784) (5, 10)\n",
      "_outputs.shape = (28, 5, 256)\n",
      "arr_state.shape = (2, 5, 256)\n"
     ]
    }
   ],
   "source": [
    "# 手写的结果 shape\n",
    "_batch_size = 5\n",
    "X_batch, y_batch = mnist.test.next_batch(_batch_size)\n",
    "print(X_batch.shape, y_batch.shape)\n",
    "_outputs, _state = np.array(sess.run([outputs, state],feed_dict={\n",
    "            X_input: X_batch, y_input: y_batch, keep_prob: 1.0, batch_size: _batch_size}))\n",
    "print('_outputs.shape =', np.asarray(_outputs).shape)\n",
    "print('arr_state.shape =', np.asarray(_state).shape)\n",
    "# 可见 outputs.shape = [ batch_size, timestep_size, hidden_size]\n",
    "# state.shape = [layer_num, 2, batch_size, hidden_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "X3 = mnist.train.images[5]\n",
    "img3 = X3.reshape([28, 28])\n",
    "plt.imshow(img3, cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们看看在分类的时候，一行一行地输入，分为各个类别的概率会是什么样子的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 1, 256)\n",
      "(28, 256)\n"
     ]
    }
   ],
   "source": [
    "X3.shape = [-1, 784]\n",
    "y_batch = mnist.train.labels[0]\n",
    "y_batch.shape = [-1, class_num]\n",
    "\n",
    "X3_outputs = np.array(sess.run(outputs, feed_dict={\n",
    "            X_input: X3, y_input: y_batch, keep_prob: 1.0, batch_size: 1}))\n",
    "print(X3_outputs.shape)\n",
    "X3_outputs.shape = [28, hidden_size]\n",
    "print(X3_outputs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABa9JREFUeJzt3UF22zYAQMGqJ/ExfTGfj1l10caVFJIwgc+ZrV8sBKI+QYiWHtu2/QVAy99XDwCA84k7QJC4AwSJO0CQuAMEiTtAkLgDBIk7QJC4AwSJO0CQuAMELRH3j8+vxAfgzPj/mHFMFR+fX5v5HefZ3F4x77M910vE/ad5Ua7Hcwb/Ju4nExjeccXJyAnwXqaI+5GDbu+l2ajH3GvGF96oMc12OQ1FU8SdsSrBnPEEeBfmfT3iDgeI3jhO5seIO7s8e+F5UV7HvPMPcSevdrK5y3s+VyjNg7gDBIk7TGil1eNKYz3iirv6jhB3gCBxB5jUkasFcQdu5S7bSOIOECTuAEHiDhAk7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4QJO4AQY9tu8Vn6ADcipU7QJC4AwSJO0CQuAMEiTtA0PJxv8tXZrHfke+h5DVze8yo+Vs+7ldwMAOzE3eAIHEHCBJ3TmePG663RNyFAuDPLBF3AP7MbeNu6wAomyLuQgtwriniDsC5xB0gSNwBgsT9ZHvfO5jxfYcZxwS8R9yZyrOTiZMNvE/cvyEiwOrE/QdddcIY8bhHToBOnDBeOu4iAtxVOu685gTI3dzlmBd3ErxPwmirHWPiDhNaKSLMSdyBl5xs1iPusJhXfwvwk2NhrCNbQeIOLLefzGviDhAk7nCA1S6zEncY5NlWxxXbIKttvbwa74j3Hkb95fUV8/7YtmWeawDeZOUOECTuAEHiDhAk7gBB4g4QJO4/aKXb0EYyD+OY23FWu5VU3AGCxB0gSNwBgsQdIEjcAYLEHSBI3AGCxB0gSNwBgsQdIEjcAYLEHSBI3ElY7UOdYDRxZxcxhbmJO0BQOu5WlsBdpeMOcFfi/g37ycDqxB0gSNzhAq4OxzK34g6QJO4wodlWnrONh9fEnf9l6wDWJe7AkvYuPEYtWGYbzxRxf7VC3PuzI4/56t/ONJ53fveen13h1Vj3jPfKuZ1pfp+NZ9Tr4Ygr5m/UY17Rhce2TXPsAXCSKVbuAJxL3AGCxB0gSNwBgsQdIEjcvzHbLWw15nYcczunK54XcQcsaILEHSBI3AGCxB0gSNwBgsQdIEjcAYLEHSBI3AGCxB0gSNwBgsQdIEjcAYLEHSBI3AGCxB0gSNwBgsQdIEjcgUPu8g1OV3xb1ZHHFHeW4avg4H3iDhAk7gBB4r6DrQFgdsvH/Vlor4jwbOMZ5dn+95G98RFztOJe/V2OI8Z5bJvjBKBm+ZU7AL8Td4AgcQcIEneAIHEHCBJ3gCBxP5l7kMdZ8X71VZjbsa6YW3EHCBJ3gCBxBwgSd4AgcQcIEneAIHEHCBJ3gCBxBwgSd4AgcQcIEneAIHEHCBJ3gCBxBwgSd4AgcQcIEneAIHEHCBL3SfgOS+BM4k6eEyd3JO4AQeIOHOKqaJwjV53i/g2X8cDqHtumYQA1Vu4AQeIOECTuAEHiDhAk7gBB4g4QJO4nc3/8OP7+YBxzO9YVcyvuAEHiDhAk7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4QJO7kfXx+bR+fX9vV46gyt3MS90kIEPBfR7rw2DY9AaixcgcIEneAIHEHCBJ3gCBxBwgSd4Cg5ePu3vBx3Hs/jrkdy9wG4g7A78QdIEjcAYLEHSBI3AGCxB0gSNwBgsT9ZO6vBWYg7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4QJO4AQem4z/bZ6rONB+h6bJveANSkV+4AdyXuAEHiDhAk7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4QJO4AQeIOECTuAEHiDhAk7gBB4g4Q9AvkiDgHvRWDuAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 28 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "h_W, h_bias = sess.run([W, bias], feed_dict={\n",
    "            X_input:X3, y_input: y_batch, keep_prob: 1.0, batch_size: 1})\n",
    "h_bias = h_bias.reshape([-1, 10])\n",
    "\n",
    "bar_index = range(class_num)\n",
    "for i in range(X3_outputs.shape[0]):\n",
    "    plt.subplot(7, 4, i+1)\n",
    "    X3_h_shate = X3_outputs[i, :].reshape([-1, hidden_size])\n",
    "    pro = sess.run(tf.nn.softmax(tf.matmul(X3_h_shate, h_W) + h_bias))\n",
    "    plt.bar(bar_index, pro[0], width=0.2 , align='center')\n",
    "    plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上面的图中，为了更清楚地看到线条的变化，我把坐标都去了，每一行显示了 4 个图，共有 7 行，表示了一行一行读取过程中，模型对字符的识别。可以看到，在只看到前面的几行像素时，模型根本认不出来是什么字符，随着看到的像素越来越多，最后就基本确定了字符的类别"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
